{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def box_iou(box1, box2, type=\"absolute\"):\n",
    "    # box1: numpy of shape [4,]\n",
    "    # box2: numpy of shape [n, 4]\n",
    "    box1 = np.expand_dims(box1, 0)\n",
    "    if len(box2.shape) == 1:\n",
    "        box2 = np.expand_dims(box2, 0)\n",
    "    if type == \"absolute\":\n",
    "        x11, y11, w1, h1 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]\n",
    "        x21, y21, w2, h2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]\n",
    "        x12, y12 = x11 + w1, y11 + h1\n",
    "        x22, y22 = x21 + w2, y21 + h2\n",
    "    elif type == \"relative\":\n",
    "        w1, h1 = box1[:,2], box1[:,3]\n",
    "        w2, h2 = box2[:,2], box2[:,3]\n",
    "        x12, y12 = w1, h1\n",
    "        x22, y22 = w2, h2\n",
    "        x11, y11, x21, y21 = 0, 0, 0, 0\n",
    "\n",
    "    center_x1, center_y1 = np.maximum(x11, x21), np.maximum(y11, y21)\n",
    "    center_x2, center_y2 = np.minimum(x12, x22), np.minimum(y12, y22)\n",
    "    intersection = np.maximum(0, center_x2 - center_x1) * np.maximum(0, center_y2 - center_y1)\n",
    "    return intersection / (w1 * h1 + w2 * h2 - intersection + 1e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.25, 0.25, 0.25])"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test 1: relative mode against three identical candidates.\n",
    "# Recorded result: 0.25 for every row.\n",
    "box1 = np.array([2, 2, 2, 2])\n",
    "box2 = np.array([[0, 0, 4, 4]] * 3)\n",
    "box_iou(box1, box2, \"relative\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.25])"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test 2: a single candidate given as a 1-D array (recorded result: [0.25]).\n",
    "box1 = np.array([2, 2, 2, 2])\n",
    "box2 = np.array([0, 0, 4, 4])\n",
    "box_iou(box1, box2, \"relative\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.25      , 0.14285714, 0.        ])"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test 3: absolute mode with a containing box, a partially overlapping box\n",
    "# and a disjoint box. Recorded result: [0.25, 0.14285714, 0.].\n",
    "box1 = np.array([2, 2, 2, 2])\n",
    "box2 = np.array([\n",
    "    [0, 0, 4, 4],\n",
    "    [2, 0, 4, 3],\n",
    "    [6, 6, 4, 4],\n",
    "])\n",
    "box_iou(box1, box2, \"absolute\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def box_iou(box1, box2, type=\"absolute\"):\n",
    "    # box1: numpy of shape [n, 4]\n",
    "    # box2: numpy of shape [n, 4]\n",
    "    if type == \"absolute\":\n",
    "        x11, y11, w1, h1 = box1\n",
    "        x21, y21, w2, h2 = box2\n",
    "        x12, y12 = x11 + w1, y11 + h1\n",
    "        x22, y22 = x21 + w2, y21 + h2\n",
    "    elif type == \"relative\":\n",
    "        _, _, w1, h1 = box1\n",
    "        _, _, w2, h2 = box2\n",
    "        x12, y12 = w1, h1\n",
    "        x22, y22 = w2, h2\n",
    "        x11, y11, x21, y21 = 0, 0, 0, 0\n",
    "\n",
    "    if w1 * h1 + w2 * h2 == 0:\n",
    "        return 0\n",
    "    else:\n",
    "        center_x1, center_y1 = max(x11, x21), max(y11, y21)\n",
    "        center_x2, center_y2 = min(x12, x22), min(y12, y22)\n",
    "        intersection = max(0, center_x2 - center_x1) * max(0, center_y2 - center_y1)\n",
    "        return intersection / (w1 * h1 + w2 * h2 - intersection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.25"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scalar box_iou on plain tuples, default absolute mode (recorded result: 0.25).\n",
    "box1 = (2, 2, 2, 2)\n",
    "box2 = (0, 0, 4, 4)\n",
    "box_iou(box1, box2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "box1 = np.array([[3,3,4,5]])\n",
    "box2 = np.array([[3,3,4,5]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each [n, 4] box array into its x, y, w and h columns.\n",
    "x11, y11, w1, h1 = (box1[:, col] for col in range(4))\n",
    "x21, y21, w2, h2 = (box2[:, col] for col in range(4))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate testcase\n",
    "import cv2\n",
    "import json\n",
    "\n",
    "\n",
    "annotation_file = \"/Users/thisiszhou/Documents/dataset/\" + \\\n",
    "                    \"coco/annotations/instances_val2014.json\"\n",
    "photo_folder = \"/Users/thisiszhou/Documents/dataset/coco/val2014\"\n",
    "\n",
    "def json_load(json_file):\n",
    "    with open(json_file, \"r\") as f:\n",
    "        return json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the annotation file once; later cells read from data.\n",
    "data = json_load(annotation_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'segmentation': [[302.66, 105.79, 313.26, 110.61, 321.94, 111.57, 323.86, 111.57, 340.25, 117.35, 364.35, 131.81, 383.63, 153.98, 391.34, 170.37, 386.52, 194.46, 364.35, 217.6, 361.46, 221.45, 347.0, 241.69, 346.03, 262.9, 347.96, 304.35, 347.0, 328.44, 349.89, 347.72, 338.32, 354.47, 313.26, 354.47, 306.51, 352.54, 306.51, 349.65, 316.15, 341.94, 319.05, 327.48, 316.15, 293.74, 304.59, 269.65, 307.48, 270.61, 307.48, 267.72, 306.51, 255.19, 306.51, 229.16, 312.3, 212.78, 313.26, 192.54, 318.08, 182.9, 284.35, 173.26, 268.92, 164.58, 260.25, 160.73, 246.75, 187.72, 235.19, 199.28, 219.77, 217.6, 216.87, 236.87, 213.02, 246.51, 193.74, 251.33, 174.46, 247.48, 171.57, 245.55, 187.96, 233.02, 201.45, 217.6, 212.05, 195.43, 230.37, 154.94, 232.3, 145.31, 240.97, 130.85, 240.97, 125.06, 236.15, 105.79, 235.19, 83.62, 254.47, 73.98, 270.85, 78.8, 280.49, 91.33]], 'area': 20284.18945, 'iscrowd': 0, 'image_id': 295957, 'bbox': [171.57, 73.98, 219.77, 280.49], 'category_id': 1, 'id': 454970}\n",
      "{'segmentation': [[330.09, 356.95, 275.39, 343.52, 245.64, 344.48, 222.62, 346.4, 207.26, 349.28, 207.26, 355.99, 260.04, 368.47, 305.14, 368.47, 334.88, 359.83]], 'area': 2162.9236000000033, 'iscrowd': 0, 'image_id': 295957, 'bbox': [207.26, 343.52, 127.62, 24.95], 'category_id': 42, 'id': 651893}\n"
     ]
    }
   ],
   "source": [
    "# Print every annotation that belongs to image 295957.\n",
    "for anno in data[\"annotations\"]:\n",
    "    if anno[\"image_id\"] != 295957:\n",
    "        continue\n",
    "    print(anno)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IOU: 61643.28730000001\n",
      "pre 0 0 220 281\n",
      "label: 171 73 391 354\n",
      "mask_label_in_pre_box shape: (381, 320)\n",
      "mask_label: (427, 640)\n"
     ]
    }
   ],
   "source": [
    "import tongyong as ty\n",
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "image_size = 427, 640\n",
    "polygon = [302.66, 105.79, 313.26, 110.61, 321.94, 111.57, 323.86, 111.57, 340.25, 117.35, 364.35, 131.81, 383.63, 153.98, 391.34, 170.37, 386.52, 194.46, 364.35, 217.6, 361.46, 221.45, 347.0, 241.69, 346.03, 262.9, 347.96, 304.35, 347.0, 328.44, 349.89, 347.72, 338.32, 354.47, 313.26, 354.47, 306.51, 352.54, 306.51, 349.65, 316.15, 341.94, 319.05, 327.48, 316.15, 293.74, 304.59, 269.65, 307.48, 270.61, 307.48, 267.72, 306.51, 255.19, 306.51, 229.16, 312.3, 212.78, 313.26, 192.54, 318.08, 182.9, 284.35, 173.26, 268.92, 164.58, 260.25, 160.73, 246.75, 187.72, 235.19, 199.28, 219.77, 217.6, 216.87, 236.87, 213.02, 246.51, 193.74, 251.33, 174.46, 247.48, 171.57, 245.55, 187.96, 233.02, 201.45, 217.6, 212.05, 195.43, 230.37, 154.94, 232.3, 145.31, 240.97, 130.85, 240.97, 125.06, 236.15, 105.79, 235.19, 83.62, 254.47, 73.98, 270.85, 78.8, 280.49, 91.33]\n",
    "# Presumably rasterises the polygon into an image-sized mask (the recorded\n",
    "# output shows shape (427, 640)) - confirm against tongyong.polygon2mask2.\n",
    "mask_label = ty.polygon2mask2(image_size, [polygon])\n",
    "\n",
    "# Predicted box: same origin as the label box, 100 units wider and taller.\n",
    "out_mask = generate_label_for_one_ob(\n",
    "    42,\n",
    "    box_label=[171.57, 73.98, 219.77, 280.49],\n",
    "    mask_label=mask_label,\n",
    "    box_pre=[171.57, 73.98, 219.77+100, 280.49+100],\n",
    "    mask_pre=np.zeros([50, 50]),\n",
    "    reg_label_encode=None,\n",
    "    mask_label_encode=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert and rescale both masks for saving as images: 1 -> 0, 0 -> 255.\n",
    "# NOTE: both names are rebound in place, so re-running this cell applies the\n",
    "# transform again to the already-scaled values.\n",
    "out_mask = 255 - out_mask*255\n",
    "mask_label = 255 - mask_label*255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write both masks to the working directory as JPEG files for visual comparison.\n",
    "# The cell displays the return value of the second imwrite call.\n",
    "cv2.imwrite(\"label_mask.jpg\", mask_label)\n",
    "cv2.imwrite(\"test_pre_mask.jpg\", out_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_label_for_one_ob(cate_label, box_label, mask_label, box_pre, mask_pre,\n",
    "                                  reg_label_encode, mask_label_encode):\n",
    "    # functions: cal ob Iou, cal mask Iou, cal x,y,w,h of label after encode\n",
    "    def cut_mask():\n",
    "        mask_pre_shape = mask_pre.shape\n",
    "        x11, y11, w1, h1 = box_label\n",
    "        x21, y21, w2, h2 = box_pre\n",
    "        x12, y12 = x11 + w1, y11 + h1\n",
    "        x22, y22 = x21 + w2, y21 + h2\n",
    "        center_x1, center_y1 = np.maximum(x11, x21), np.maximum(y11, y21)\n",
    "        center_x2, center_y2 = np.minimum(x12, x22), np.minimum(y12, y22)\n",
    "        intersection = np.maximum(0, center_x2 - center_x1) * np.maximum(0, center_y2 - center_y1)\n",
    "        print(\"IOU:\", intersection)\n",
    "        if intersection <= 0:\n",
    "            return np.zeros(mask_pre_shape, dtype=np.float32)\n",
    "        else:\n",
    "            mask_label_in_pre_box = np.zeros([int(h2) + 1, int(w2) + 1],dtype=np.float32)\n",
    "            pre_by1, pre_by2 = int(center_y1 - y21), int(center_y2 - y21) + 1\n",
    "            pre_bx1, pre_bx2 = int(center_x1 - x21), int(center_x2 - x21) + 1\n",
    "            print(\"pre\", pre_bx1, pre_by1, pre_bx2, pre_by2)\n",
    "            label_y1, label_y2 = int(center_y1), int(center_y1) + pre_by2 - pre_by1\n",
    "            label_x1, label_x2 = int(center_x1), int(center_x1) + pre_bx2 - pre_bx1\n",
    "            print(\"label:\", label_x1, label_y1, label_x2, label_y2)\n",
    "            print(\"mask_label_in_pre_box shape:\", mask_label_in_pre_box.shape)\n",
    "            print(\"mask_label:\", mask_label.shape)\n",
    "            mask_label_in_pre_box[pre_by1: pre_by2, pre_bx1: pre_bx2] = mask_label[label_y1: label_y2, label_x1: label_x2]\n",
    "            #cv2.imwrite(\"mask_label_in_box.jpg\", mask_label[label_y1: label_y2, label_x1: label_x2]*255)\n",
    "            mask_label_in_pre_box = cv2.resize(mask_label_in_pre_box, mask_pre_shape)\n",
    "\n",
    "            return mask_label_in_pre_box\n",
    "    return cut_mask()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# part test\n",
    "import tongyong as ty\n",
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "image_size = 427, 640\n",
    "polygon = [302.66, 105.79, 313.26, 110.61, 321.94, 111.57, 323.86, 111.57, 340.25, 117.35, 364.35, 131.81, 383.63, 153.98, 391.34, 170.37, 386.52, 194.46, 364.35, 217.6, 361.46, 221.45, 347.0, 241.69, 346.03, 262.9, 347.96, 304.35, 347.0, 328.44, 349.89, 347.72, 338.32, 354.47, 313.26, 354.47, 306.51, 352.54, 306.51, 349.65, 316.15, 341.94, 319.05, 327.48, 316.15, 293.74, 304.59, 269.65, 307.48, 270.61, 307.48, 267.72, 306.51, 255.19, 306.51, 229.16, 312.3, 212.78, 313.26, 192.54, 318.08, 182.9, 284.35, 173.26, 268.92, 164.58, 260.25, 160.73, 246.75, 187.72, 235.19, 199.28, 219.77, 217.6, 216.87, 236.87, 213.02, 246.51, 193.74, 251.33, 174.46, 247.48, 171.57, 245.55, 187.96, 233.02, 201.45, 217.6, 212.05, 195.43, 230.37, 154.94, 232.3, 145.31, 240.97, 130.85, 240.97, 125.06, 236.15, 105.79, 235.19, 83.62, 254.47, 73.98, 270.85, 78.8, 280.49, 91.33]\n",
    "mask_label = ty.polygon2mask2(image_size, [polygon])\n",
    "# Corner coordinates (x1, y1) and (x2, y2) of the window to inspect.\n",
    "label_x1, label_y1 = 207, 343\n",
    "label_x2, label_y2 = 335, 368"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "a=np.array([[1,0],\n",
    "           [2,3]], dtype=np.float32)\n",
    "np.resize(a, [3,3])\n",
    "import cv2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "ename": "SystemError",
     "evalue": "new style getargs format but argument is not a tuple",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mSystemError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-129-3aa756823652>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcv2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mSystemError\u001b[0m: new style getargs format but argument is not a tuple"
     ]
    }
   ],
   "source": [
    "# cv2.resize expects dsize as a tuple (width, height); passing a list raised\n",
    "# 'SystemError: new style getargs format but argument is not a tuple'.\n",
    "cv2.resize(a, (3, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 2.],\n",
       "       [3., 1., 0.],\n",
       "       [2., 3., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# np.resize does not interpolate: the recorded output shows the flattened\n",
    "# input (1, 0, 2, 3) repeated until the 3x3 shape is filled.\n",
    "np.resize(a, (3, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tuple"
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the type of ndarray.shape: the recorded output is tuple, not list.\n",
    "type(a.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'tongyong'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-132-6403f4e46307>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtongyong\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tongyong'"
     ]
    }
   ],
   "source": [
    "import tongyong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'VIRTUALENVWRAPPER_SCRIPT': '/usr/local/bin/virtualenvwrapper.sh',\n",
       " 'VIRTUALENVWRAPPER_PROJECT_FILENAME': '.project',\n",
       " 'TERM_PROGRAM': 'Apple_Terminal',\n",
       " 'SHELL': '/bin/bash',\n",
       " 'TERM': 'xterm-color',\n",
       " 'TMPDIR': '/var/folders/s4/5lzw205d44g9zrfql0xbf2d40000gn/T/',\n",
       " 'Apple_PubSub_Socket_Render': '/private/tmp/com.apple.launchd.x2O0mfIl7P/Render',\n",
       " 'TERM_PROGRAM_VERSION': '421.2',\n",
       " 'TERM_SESSION_ID': '8D0AAE18-AE4B-4A46-91BB-FB951CBFD33F',\n",
       " 'USER': 'thisiszhou',\n",
       " 'SSH_AUTH_SOCK': '/private/tmp/com.apple.launchd.Eat0w1a5Qz/Listeners',\n",
       " 'VIRTUAL_ENV': '/Users/thisiszhou/.virtualenvs/learn',\n",
       " 'WORKON_HOME': '/Users/thisiszhou/.virtualenvs',\n",
       " 'VIRTUALENVWRAPPER_PYTHON': '/usr/local/bin/python3',\n",
       " 'PATH': '/Users/thisiszhou/.virtualenvs/learn/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/Library/Frameworks/Python.framework/Versions/3.6/bin',\n",
       " 'VIRTUALENVWRAPPER_HOOK_DIR': '/Users/thisiszhou/.virtualenvs',\n",
       " 'PWD': '/Users/thisiszhou',\n",
       " 'LANG': 'zh_CN.UTF-8',\n",
       " 'XPC_FLAGS': '0x0',\n",
       " 'PS1': '(learn) \\\\h:\\\\W \\\\u\\\\$ ',\n",
       " 'XPC_SERVICE_NAME': '0',\n",
       " 'SHLVL': '1',\n",
       " 'HOME': '/Users/thisiszhou',\n",
       " 'LOGNAME': 'thisiszhou',\n",
       " 'VIRTUALENVWRAPPER_WORKON_CD': '1',\n",
       " '_': '/Users/thisiszhou/.virtualenvs/learn/bin/jupyter',\n",
       " '__CF_USER_TEXT_ENCODING': '0x1F5:0x19:0x34',\n",
       " 'KERNEL_LAUNCH_TIMEOUT': '40',\n",
       " 'JPY_PARENT_PID': '610',\n",
       " 'CLICOLOR': '1',\n",
       " 'PAGER': 'cat',\n",
       " 'GIT_PAGER': 'cat',\n",
       " 'MPLBACKEND': 'module://ipykernel.pylab.backend_inline'}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mask iou\n",
    "mask1 = np.zeros([10,10])\n",
    "mask2 = np.ones([10,10])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "inter_fenzi = np.ones([10,10])\n",
    "inter_fenmu = np.zeros([10,10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Union: mark every pixel that is positive in at least one of the masks.\n",
    "inter_fenmu[(mask1 > 0) | (mask2 > 0)] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intersection: clear every pixel that is zero in at least one of the masks.\n",
    "inter_fenzi[(mask1 == 0) | (mask2 == 0)] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_IoU(mask1, mask2):\n",
    "    assert mask1.shape == mask2.shape\n",
    "    shape = mask1.shape\n",
    "    point_sum = np.zeros(shape)\n",
    "    point_inter = np.ones(shape)\n",
    "    \n",
    "    point_sum[mask1 >= 0.5]=1\n",
    "    point_sum[mask2 >= 0.5]=1\n",
    "    point_inter[mask1 < 0.5]=0\n",
    "    point_inter[mask2 < 0.5]=0\n",
    "    \n",
    "    print(\"point_sum:\", point_sum)\n",
    "    print(\"point_inter:\", point_inter)\n",
    "    return np.sum(point_inter) / (np.sum(point_sum) + 1e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask1 = np.array([[0,0,0,0,0,0],\n",
    "                  [0,0,0,0,0,0],\n",
    "                  [0,0,1,1,0,0],\n",
    "                  [0,0,0,1,0,0],\n",
    "                  [0,0,0,0,0,0]\n",
    "                 ])\n",
    "mask2 = np.array([[0,0,0,0,0,0],\n",
    "                  [0,0,0,0,0,0],\n",
    "                  [0,0,1,1,0,0],\n",
    "                  [0,0,1,1,0,0],\n",
    "                  [0,0,0,0,0,0]\n",
    "                 ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "point_sum: [[0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 1. 0. 0.]\n",
      " [0. 0. 1. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0.]]\n",
      "point_inter: [[0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 1. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0.]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.74999999998125"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3 overlapping pixels over a 4-pixel union: 0.75, minus the effect of the 1e-10 epsilon.\n",
    "mask_IoU(mask1, mask2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
